#include <zMat.hpp>
#include <UnitTest++.h>


namespace zzz{
  // Test fixture with a deliberately non-standard assignment: assigning from
  // `other` stores other.x + 1 rather than other.x. An IOObject<ClassA>
  // specialisation exists in this file, and the SMART_COPY test expects a
  // copied matrix of ClassA to keep x unchanged (i.e. this operator is not
  // what performs the copy there).
  struct ClassA{
    int x;
    void operator =(const ClassA &other)
    {
      const int bumped=other.x+1;
      x=bumped;
    }
  };
  // IOObject specialisation for ClassA: copies `size` elements from `src` to
  // `dst` as raw bytes, so ClassA::operator= (which adds 1) is never invoked.
  // NOTE(review): presumably zMatrix routes element copies through
  // IOObject<T>::CopyData -- confirm against zMat.hpp.
  template<>
  class IOObject<ClassA> {
  public:
    static void CopyData(ClassA* dst, const ClassA* src, zsize size) {
      void* to=dst;
      const void* from=src;
      memcpy(to,from,size*sizeof(*dst));
    }
  };
  // Test fixture mirroring ClassA (assignment stores other.x + 1) but with no
  // IOObject specialisation in this file. The SMART_COPY test expects a copied
  // matrix of ClassB to show x incremented, i.e. the copy goes through this
  // operator.
  struct ClassB{
    int x;
    void operator =(const ClassB &other)
    {
      const int bumped=other.x+1;
      x=bumped;
    }
  };
}

using namespace zzz;

SUITE(ZMATRIXTEST)
{
  // Covers the zMatrix<double> constructors: explicit (3,3) dimensions,
  // construction from the Ones / Zeros / Diag expressions, and the
  // single-argument form.
  TEST(CONSTRUCT)
  {
    const int kDim=3;

    // Dimension-only construction; element values are not inspected.
    zMatrix<double> sized(3,3);
    CHECK_EQUAL(3,sized.Size(0));
    CHECK_EQUAL(3,sized.Size(1));

    // Built from Ones: every element is expected to be 1.
    zMatrix<double> allOnes(Ones<double>(3,3));
    for (int row=0; row<kDim; ++row)
    {
      for (int col=0; col<kDim; ++col)
      {
        CHECK_EQUAL(1,allOnes(row,col));
      }
    }

    // Built from Zeros: every element is expected to be 0.
    zMatrix<double> allZeros(Zeros<double>(3,3));
    for (int row=0; row<kDim; ++row)
    {
      for (int col=0; col<kDim; ++col)
      {
        CHECK_EQUAL(0,allZeros(row,col));
      }
    }

    // Diag of Ones<double>(3,1): 1 expected on the diagonal, 0 elsewhere.
    zMatrix<double> diagonal(Diag<double>(Ones<double>(3,1)));
    for (int row=0; row<kDim; ++row)
    {
      for (int col=0; col<kDim; ++col)
      {
        if (row!=col)
        {
          CHECK_EQUAL(0,diagonal(row,col));
        }
        else
        {
          CHECK_EQUAL(1,diagonal(row,col));
        }
      }
    }

    // Original note: "correct construct, not simply memcpy".
    // Writing into the first matrix must leave the Ones-built matrix untouched.
    sized(0,0)=2;
    CHECK_EQUAL(1,allOnes(0,0));

    // Single-argument construction is expected to report a 3 x 1 shape.
    zMatrix<double> single(3);
    CHECK_EQUAL(3,single.Size(0));
    CHECK_EQUAL(1,single.Size(1));
  }

  // Assigning a 3x3 all-ones matrix to a default-constructed matrix must
  // yield 1 at every position of the target.
  TEST(COPY)
  {
    zMatrix<double> source(Ones<double>(3,3));
    zMatrix<double> target;
    target=source;

    const int kDim=3;
    for (int row=0; row<kDim; ++row)
    {
      for (int col=0; col<kDim; ++col)
      {
        CHECK_EQUAL(1,target(row,col));
      }
    }
  }

  // Copy-constructing a matrix must pick the element copy strategy per type:
  //  - ClassA has an IOObject specialisation (raw memcpy), so x is expected
  //    to stay 0 in the copy;
  //  - ClassB has none in this file, so its operator= (which adds 1) is
  //    expected to run and x becomes 1.
  TEST(SMART_COPY)
  {
    zMatrix<ClassA> rawSource(1,1);
    rawSource(0,0).x=0;
    zMatrix<ClassA> rawCopy(rawSource);
    CHECK_EQUAL(0,rawCopy(0,0).x);

    zMatrix<ClassB> assignedSource(1,1);
    assignedSource(0,0).x=0;
    zMatrix<ClassB> assignedCopy(assignedSource);
    CHECK_EQUAL(1,assignedCopy(0,0).x);
  }

  // Two matrices built independently from Ones<double>(3,3) must compare
  // equal through operator==.
  TEST(EQUALS)
  {
    zMatrix<double> v1(Ones<double>(3,3)), v2(Ones<double>(3,3));
    CHECK(v1==v2);
  }

  // Scalar scaling and element-wise addition of constant-valued matrices.
  TEST(CONSTANT_EXPRESSION)
  {
    const int kDim=3;

    // ones * 2.0 -> every element expected to be 2
    zMatrix<double> ones(Ones<double>(3,3));
    zMatrix<double> twos(ones*2.0);
    for (int row=0; row<kDim; ++row)
    {
      for (int col=0; col<kDim; ++col)
      {
        CHECK_EQUAL(2,twos(row,col));
      }
    }

    // (Ones * 3.0) + twos -> every element expected to be 2 + 3
    zMatrix<double> threes(Ones<double>(3,3)*3.0);
    zMatrix<double> fives(twos+threes);
    for (int row=0; row<kDim; ++row)
    {
      for (int col=0; col<kDim; ++col)
      {
        CHECK_EQUAL(2+3,fives(row,col));
      }
    }
  }

  // A compound expression a + DotTimes(b,c) - d must match the same formula
  // evaluated element by element (within EPSILON).
  TEST(MATRIX_EXPRESSION)
  {
    // The four Rand calls are kept in their original order.
    zMatrix<double> a(Rand(3,3));
    zMatrix<double> b(Rand(3,3)*2.0);
    zMatrix<double> c(Rand(3,3)*3.0);
    zMatrix<double> d(Rand(3,3)*4.0);
    zMatrix<double> combined(a+DotTimes(b,c)-d);

    const int kDim=3;
    for (int row=0; row<kDim; ++row)
    {
      for (int col=0; col<kDim; ++col)
      {
        CHECK_CLOSE(a(row,col)+b(row,col)*c(row,col)-d(row,col), combined(row,col), EPSILON);
      }
    }
  }

  // v3 is built from the expression v1<v2; each element is expected to be 1
  // where v1[i]<v2[i] holds and 0 otherwise.
  TEST(MATRIX_COMPARE)
  {
    zMatrix<double> v1(Rand(3,3));
    zMatrix<double> v2(Rand(3,3));
    zMatrix<double> v3(v1<v2);
    for (zuint i=0; i<v3.size(); ++i)
    {
      if (v1[i]<v2[i])
      {
        CHECK(v3[i]==1);
      }
      else
      {
        CHECK(v3[i]==0);
      }
    }
  }

  // Element-wise function (Sin) and the unary + / - operators on a matrix.
  TEST(FUNCTION_EXPRESSION)
  {
    const int kDim=3;

    // Sin(matrix) must agree with sin() applied to each element.
    zMatrix<double> input(Ones<double>(3,3));
    zMatrix<double> sines(Sin(input));
    for (int row=0; row<kDim; ++row)
    {
      for (int col=0; col<kDim; ++col)
      {
        CHECK_EQUAL(sin(input(row,col)),sines(row,col));
      }
    }

    // Unary minus must be the element-wise negation of unary plus.
    zMatrix<double> plus;
    zMatrix<double> minus;
    plus=+input;
    minus=-input;
    for (int row=0; row<kDim; ++row)
    {
      for (int col=0; col<kDim; ++col)
      {
        CHECK_EQUAL(-plus(row,col),minus(row,col));
      }
    }
  }

  // Matrix product of a 3x3 and a 3x2 matrix, checked against hand-computed
  // values.
  TEST(MATRIX_PRODUCT)
  {
    const int kRows=3;
    const int kInner=3;
    const int kCols=2;

    const int lhsValues[3][3]={
      {1,1,2},
      {1,3,1},
      {4,1,1}};
    const int rhsValues[3][2]={
      {1,2},
      {3,3},
      {2,1}};
    // lhs * rhs, worked out by hand.
    const int expected[3][2]={
      { 8, 7},
      {12,12},
      { 9,12}};

    zMatrix<double> lhs(3,3);
    for (int row=0; row<kRows; ++row)
    {
      for (int col=0; col<kInner; ++col)
      {
        lhs(row,col)=lhsValues[row][col];
      }
    }

    zMatrix<double> rhs(3,2);
    for (int row=0; row<kInner; ++row)
    {
      for (int col=0; col<kCols; ++col)
      {
        rhs(row,col)=rhsValues[row][col];
      }
    }

    zMatrix<double> product(lhs*rhs);
    CHECK_EQUAL(3,product.Size(0));
    CHECK_EQUAL(2,product.Size(1));
    for (int row=0; row<kRows; ++row)
    {
      for (int col=0; col<kCols; ++col)
      {
        CHECK_EQUAL(expected[row][col],product(row,col));
      }
    }
  }

  // Sub-matrix access through Colon ranges, both as an assignment target and
  // as a source expression.
  TEST(SUB_MATRIX)
  {
    zMatrix<double> oneByOne(Ones<double>(1,1));
    zMatrix<double> zeroBlock(Zeros<double>(2,3));
    zMatrix<double> twoRow(Ones<double>(1,2)*2.0);

    // Assemble a 3x3 matrix from three pieces written into sub-ranges.
    zMatrix<double> assembled(3,3);
    assembled(Colon(0,1),Colon())=zeroBlock;
    assembled(2,Colon(2,2))=oneByOne;
    assembled(Colon(2,2),Colon(0,1))=twoRow;

    const int expected[3][3]={
      {0,0,0},
      {0,0,0},
      {2,2,1}};
    for (int row=0; row<3; ++row)
    {
      for (int col=0; col<3; ++col)
      {
        CHECK_EQUAL(expected[row][col],assembled(row,col));
      }
    }

    // Rows 3..5 and columns 6..9 of a 10x10 all-ones matrix: expected 3 x 4.
    zMatrix<double> window(Ones<double>(10,10)(Colon(3,5),Colon(6,9)));
    CHECK_EQUAL(3,window.Size(0));
    CHECK_EQUAL(4,window.Size(1));
    for (int row=0; row<3; ++row)
    {
      for (int col=0; col<4; ++col)
      {
        CHECK_EQUAL(1,window(row,col));
      }
    }

    // A single row index combined with a column range: expected 1 x 10.
    zMatrix<double> singleRow(Ones<double>(10,10)(3,Colon(0,9)));
    CHECK_EQUAL(1,singleRow.Size(0));
    CHECK_EQUAL(10,singleRow.Size(1));
    for (int row=0; row<1; ++row)
    {
      for (int col=0; col<10; ++col)
      {
        CHECK_EQUAL(1,singleRow(row,col));
      }
    }
  }

  // Stepped / reversed Colon ranges and copying one sub-block of a matrix
  // into another sub-block of the same matrix.
  TEST(FANCY_SUB_MATRIX)
  {
    // Fill each row from a transposed Colond range.
    zMatrix<double> grid(3,3);
    grid(0,Colon())=Trans(Colond(1.0, 2.0, 0.5));
    grid(1,Colon())=Trans(Colond(3.0, 7.0, 2.0));
    grid(2,Colon())=Trans(Colond(100.0, 80.0, -10.0));

    // Reverse both axes with a -1 step; element (r,c) must land at (2-r,2-c).
    zMatrix<double> flipped(3,3);
    flipped=grid(Colon(2,0,-1), Colon(2,0,-1));
    for (int row=0; row<3; ++row)
    {
      for (int col=0; col<3; ++col)
      {
        CHECK_EQUAL(grid(row,col),flipped(2-row,2-col));
      }
    }

    // Copy the upper-left 2x2 block (set to ones) into the lower-right block.
    zMatrix<double> blocks(Zeros<double>(4,4));
    blocks(Colon(0,1),Colon(0,1))=Ones<double>(2,2);
    blocks(Colon(2,3),Colon(2,3))=blocks(Colon(0,1),Colon(0,1));
    for (int row=2; row<4; ++row)
    {
      for (int col=2; col<4; ++col)
      {
        CHECK_EQUAL(1,blocks(row,col));
      }
    }
  }

  // Trans of a 2x3 matrix: element (r,c) of the source must appear at (c,r)
  // of the result.
  TEST(TRANSPOSE)
  {
    const int values[2][3]={
      {1,2,3},
      {4,5,6}};

    zMatrix<double> source(2,3);
    for (int row=0; row<2; ++row)
    {
      for (int col=0; col<3; ++col)
      {
        source(row,col)=values[row][col];
      }
    }

    zMatrix<double> transposed(Trans(source));
    for (int row=0; row<2; ++row)
    {
      for (int col=0; col<3; ++col)
      {
        CHECK_EQUAL(values[row][col],transposed(col,row));
      }
    }
  }

  // zVector: single-index write, two-index read, and copy independence.
  TEST(ZVECTOR)
  {
    const int kLength=10;

    zVector<double> original(10);
    for (int idx=0; idx<kLength; ++idx)
    {
      original(idx)=idx;
    }

    // Read back through the (row, col) accessor using column 0 only.
    for (int row=0; row<kLength; ++row)
    {
      for (int col=0; col<1; ++col)
      {
        CHECK_EQUAL(row,original(row,col));
      }
    }

    // Changing the original after copying must not affect the copy.
    zVector<double> duplicate(original);
    original(0)=100;
    CHECK_EQUAL(0,duplicate(0));
  }

  // Reductions on a 2x3 all-ones matrix: Sum, Dot with itself, and Norm.
  TEST(MatrixOpe)
  {
    zMatrix<double> a(Ones<double>(2,3));
    // Sums of six ones are exactly representable, so exact equality is safe.
    CHECK_EQUAL(6,Sum(a));
    CHECK_EQUAL(6,Dot(a,a));
    // sqrt(6) is irrational: compare with a tolerance rather than bit-for-bit,
    // so the check does not depend on how Norm rounds internally. This matches
    // the CHECK_CLOSE/EPSILON convention used by MATRIX_EXPRESSION.
    CHECK_CLOSE(Sqrt<double>(6),Norm(a),EPSILON);
  }

  // Concatenating matrix expressions: the comma operator joins pieces side by
  // side (result 1 x 6), operator% stacks them (result 2 x 3).
  TEST(MatrixCombine)
  {
    zMatrix<double> a;

    // Ones(1,3), Zeros(1,2) and the scalar 100 joined into one row.
    a=(Ones<double>(1,3),Zeros<double>(1,2),Dress(100.0));
    CHECK_EQUAL(1,a.Size(0));
    CHECK_EQUAL(6,a.Size(1));
    const int joined[1][6]={{1,1,1,0,0,100}};
    for (int row=0; row<1; ++row)
    {
      for (int col=0; col<6; ++col)
      {
        CHECK_EQUAL(joined[row][col],a(row,col));
      }
    }

    // Transposed Colond(1.0,3.0) on top of a row of zeros.
    a=(Trans(Colond(1.0,3.0)) % Zeros<double>(1,3));
    CHECK_EQUAL(2,a.Size(0));
    CHECK_EQUAL(3,a.Size(1));
    const int stacked[2][3]={
      {1,2,3},
      {0,0,0}};
    for (int row=0; row<2; ++row)
    {
      for (int col=0; col<3; ++col)
      {
        CHECK_EQUAL(stacked[row][col],a(row,col));
      }
    }
  }

  // operator^ in its three forms (matrix^scalar, scalar^matrix, matrix^matrix)
  // must agree with Pow applied element by element.
  TEST(MatrixPow)
  {
    zMatrixd base(Colond(1,3)%Colond(2,4));

    // matrix ^ scalar
    zMatrixd squared(base^2.0);
    for (zuint row=0; row<squared.Size(0); ++row)
    {
      for (zuint col=0; col<squared.Size(1); ++col)
      {
        CHECK_EQUAL(Pow(base(row,col),2.0),squared(row,col));
      }
    }

    // scalar ^ matrix
    zMatrixd powersOfTwo(2.0^base);
    for (zuint row=0; row<powersOfTwo.Size(0); ++row)
    {
      for (zuint col=0; col<powersOfTwo.Size(1); ++col)
      {
        CHECK_EQUAL(Pow(2.0,base(row,col)),powersOfTwo(row,col));
      }
    }

    // matrix ^ matrix (element-wise)
    zMatrixd selfPower(base^base);
    for (zuint row=0; row<selfPower.Size(0); ++row)
    {
      for (zuint col=0; col<selfPower.Size(1); ++col)
      {
        CHECK_EQUAL(Pow(base(row,col),base(row,col)),selfPower(row,col));
      }
    }
  }

  // GradientX / GradientY on a 3x4 matrix, compared with hand-computed values.
  // The expected tables correspond to one-sided differences on the border and
  // halved central differences in the interior.
  TEST(MatrixGradient)
  {
    const double samples[3][4]={
      {1,2,5,4},
      {2,2,3,5},
      {6,6,6,8}};
    // Differences taken along each row (second index).
    const double expectedX[3][4]={
      {1,  2,  1, -1},
      {0,0.5,1.5,  2},
      {0,  0,  1,  2}};
    // Differences taken along each column (first index).
    const double expectedY[3][4]={
      {  1,0, -2,1},
      {2.5,2,0.5,2},
      {  4,4,  3,3}};

    zMatrixd mat(3,4);
    for (zuint row=0; row<3; ++row)
    {
      for (zuint col=0; col<4; ++col)
      {
        mat(row,col)=samples[row][col];
      }
    }

    zMatrixd gradx(GradientX(mat));
    for (zuint row=0; row<3; ++row)
    {
      for (zuint col=0; col<4; ++col)
      {
        CHECK_EQUAL(expectedX[row][col],gradx(row,col));
      }
    }

    zMatrixd grady(GradientY(mat));
    for (zuint row=0; row<3; ++row)
    {
      for (zuint col=0; col<4; ++col)
      {
        CHECK_EQUAL(expectedY[row][col],grady(row,col));
      }
    }
  }

  // Reshape of a 1x9 matrix holding 0..8 into 3x3. The expected layout fills
  // the result column by column: B(r,c) == c*3 + r.
  TEST(MatrixReshape)
  {
    zMatrix<double> A(1,9);
    for (zuint i=0; i<9; ++i)
    {
      A(0,i)=i;
    }

    zMatrix<double> B(Reshape(A,3,3));
    for (int col=0; col<3; ++col)
    {
      for (int row=0; row<3; ++row)
      {
        CHECK_EQUAL(col*3+row,B(row,col));
      }
    }
  }
}